# coding=utf-8

""" ReservoirTransformer model configuration"""
""" Author: Md Kowsher"""
from collections import OrderedDict
from typing import Mapping

from transformers import PretrainedConfig
from transformers.onnx import OnnxConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)




class ReservoirTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ReservoirTModel`]. It is used to
    instantiate a ReservoirTTimeSeries model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Every argument below is stored unchanged as an attribute of the same name. List-valued arguments are
    shallow-copied first so that instances never share (and cannot corrupt) the default lists.


    Args:

        hidden_size (`int`, *optional*, defaults to 8):
            Dimensionality of the encoder layers and the pooler layer.
        output_size (`int`, *optional*, defaults to `None`):
            The output dimension of the prediction value. In general, for multivariate time series, all features
            are used for prediction.
        re_output_size (`int`, *optional*, defaults to 4):
            The reservoir output dimension.
        num_hidden_layers (`int`, *optional*, defaults to 4):
            Number of hidden layers in the Transformer encoder.
        pred_len (`int`, *optional*, defaults to 720):
            The multivariate horizon to predict.
        num_attention_heads (`int`, *optional*, defaults to 4):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 64):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_sequence_length (`int`, *optional*, defaults to 500):
            The maximum sequence length.
        sequence_length (`int`, *optional*, defaults to 12):
            The sequence length of the input, i.e. the look-back window used to capture the previous history.
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size [mask or non_mask here] of the `token_type_ids` passed when calling
            [`ReservoirTModel`].
        num_reservoirs (`int`, *optional*, defaults to 20):
            The number of reservoirs used for ensembling (group reservoir).
        reservoir_size (`List[int]`, *optional*, defaults to a 20-element list, see the signature):
            The reservoir sizes of the group reservoir.
        spectral_radius (`List[float]`, *optional*, defaults to a 20-element list, see the signature):
            The spectral radius of each reservoir in the group reservoir.
        sparsity (`List[float]`, *optional*, defaults to a 20-element list, see the signature):
            The sparsity rate of each reservoir in the group reservoir.
        leaky (`List[float]`, *optional*, defaults to a 20-element list, see the signature):
            The leaky rate of each reservoir in the group reservoir.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Forwarded to `PretrainedConfig.__init__`.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        decoder_dropout (`float`, *optional*):
            The dropout ratio for the classification or regression head.
        problem_type (`str`, *optional*):
            Type of problem such as 'regression', 'single_label_classification', 'multi_label_classification'.
        soft_border (`int`, *optional*, defaults to 8):
            Stored as-is; its meaning is defined by the consuming code (not visible here).
        batch (`int`, *optional*, defaults to 64):
            Stored as-is; presumably a batch size -- confirm against the consuming code.
        train_size (`float`, *optional*, defaults to 0.7):
            Stored as-is; presumably the training split fraction -- confirm against the consuming code.
        val_size (`float`, *optional*, defaults to 0.1):
            Stored as-is; presumably the validation split fraction -- confirm against the consuming code.
        test_size (`float`, *optional*, defaults to 0.2):
            Stored as-is; presumably the test split fraction -- confirm against the consuming code.
        scaling (`bool`, *optional*, defaults to `True`):
            Stored as-is; presumably toggles input scaling -- confirm against the consuming code.
        **kwargs:
            Any additional keyword arguments (for example `is_decoder`) are forwarded to
            `PretrainedConfig.__init__`.


    Examples:

    ```python
    >>> from configuration import ReservoirTConfig

    >>> # Initializing a transformer style configuration
    >>> configuration = ReservoirTConfig()

    >>> # Initializing a model (with random weights) from the transformer style configuration
    >>> model = ReservoirTTimeSeries(config = configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "ReservoirTransformer"

    def __init__(
        self,
        hidden_size=8,
        output_size=None,
        re_output_size=4,
        num_hidden_layers=4,
        pred_len=720,
        num_attention_heads=4,
        intermediate_size=64,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_sequence_length=500,
        sequence_length=12,
        type_vocab_size=2,
        num_reservoirs=20,
        reservoir_size = [100, 105, 110, 115, 120, 125, 130,  135, 140, 145,100, 105, 110, 115, 120, 125, 130,  135, 140, 145],
        spectral_radius = [0.6, 0.8, 0.55, 0.6, 0.5, 0.4, 0.3, 0.2, 0.81, 0.05,0.85, 0.86, 0.87, 0.88, 0.89, 0.90, 0.91, 0.92, 0.93, 0.94],
        sparsity = [0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, 0.85, 0.86, 0.87, 0.88, 0.89, 0.90, 0.91, 0.92, 0.93, 0.94],
        leaky = [0.075, 0.15, 0.225, 0.30, 0.375, 0.45, 0.525, 0.6, 0.675, 0.75,0.075, 0.15, 0.225, 0.30, 0.375, 0.45, 0.525, 0.6, 0.675, 0.75],
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        decoder_dropout=None,
        problem_type=None,
        soft_border=8,
        batch=64,
        train_size=0.7,
        val_size=0.1,
        test_size=0.2,
        scaling=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        def _own_copy(value):
            # The list defaults above are created once at definition time; storing
            # them by reference would let an in-place edit on one config leak into
            # every other config. Non-list values are passed through untouched.
            return list(value) if isinstance(value, list) else value

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.decoder_dropout = decoder_dropout
        self.output_size = output_size
        self.re_output_size = re_output_size
        self.pred_len = pred_len
        self.max_sequence_length = max_sequence_length
        self.sequence_length = sequence_length
        self.problem_type = problem_type
        self.num_reservoirs = num_reservoirs
        self.spectral_radius = _own_copy(spectral_radius)
        self.sparsity = _own_copy(sparsity)
        self.reservoir_size = _own_copy(reservoir_size)
        self.leaky = _own_copy(leaky)
        self.soft_border = soft_border
        self.batch = batch
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.scaling = scaling


class ReservoirTOnnxConfig(OnnxConfig):
    """ONNX export configuration declaring the dynamic axes of the model inputs."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map each input name to its dynamic axes.

        For the ``"multiple-choice"`` task the axes are batch / choice /
        sequence; for every other task they are batch / sequence. All three
        inputs share the same axis mapping object.
        """
        if self.task != "multiple-choice":
            axes = {0: "batch", 1: "sequence"}
        else:
            axes = {0: "batch", 1: "choice", 2: "sequence"}

        input_names = ("input_ids", "attention_mask", "token_type_ids")
        return OrderedDict((input_name, axes) for input_name in input_names)
    

import os
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
from scipy.sparse.linalg import eigs
#os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import PatchTSTModel, PatchTSTConfig, TrainingArguments, EarlyStoppingCallback, Trainer, PatchTSTForPrediction, PatchTSMixerConfig, PatchTSMixerForPrediction
#from reservoir_computing.modules import RC_model
#from configuration import ReservoirTConfig
from tqdm import tqdm
from datasets import Dataset
import wandb
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
import numpy as np
from datetime import datetime
from scipy.linalg import eigvals
import statistics


# NOTE: this runs at import time and starts a Weights & Biases run.
wandb.init(
    # set the wandb project where this run will be logged
    project="EchoSolo",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.002,
    "epochs": 100,
    }
)

# When switching datasets, always remember to update output_size and embedding_size.

# Module-level configuration shared by the model classes below; starts from the
# ReservoirTConfig defaults and overrides them for this experiment.
configuration = ReservoirTConfig()

configuration.output_size=140
configuration.re_output_size=21
configuration.max_sequence_length=1000
configuration.sequence_length=336
configuration.pred_len=720
configuration.hidden_size=7
configuration.num_attention_heads=7
configuration.hidden_dropout_prob=0.1
configuration.num_hidden_layers=16
configuration.num_reservoirs = 30
configuration.intermediate_size=128
# The five per-reservoir lists below are indexed by reservoir number, so each
# needs at least num_reservoirs (30) entries.
configuration.reservoir_size = [105, 110, 115, 120, 125, 130, 135,  140, 145, 150,105, 110, 115, 120, 125, 130, 135,  140, 145, 150,105, 110, 115, 120, 125, 130, 135,  140, 145, 150]
configuration.spectral_radius = [0.6, 0.8, 0.55, 0.6, 0.5, 0.4, 0.3, 0.2, 0.81, 0.05,0.85, 0.86, 0.87, 0.88, 0.89, 0.90, 0.91, 0.92, 0.93, 0.94,0.85, 0.86, 0.87, 0.88, 0.89, 0.90, 0.91, 0.92, 0.93, 0.94]
configuration.sparsity = [0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, 0.85, 0.86, 0.87, 0.88, 0.89, 0.90, 0.91, 0.92, 0.93, 0.94,0.85, 0.86, 0.87, 0.88, 0.89, 0.90, 0.91, 0.92, 0.93, 0.94]
configuration.leaky = [0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59,0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59,0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59]
configuration.activation_function = ["tanh","sigmoid","relu","tanh","sigmoid","relu","tanh","sigmoid","relu","tanh","sigmoid","relu","tanh","sigmoid","relu","tanh","sigmoid","relu","tanh","sigmoid","tanh","sigmoid","relu","tanh","sigmoid","relu","tanh","sigmoid","relu","tanh"]
configuration.period_type = "hour"
#configuration.reservoir_size = 1000
configuration.attention_probs_dropout_prob=0.0
configuration.batch_size = 16
configuration.embedding_size = 140
configuration.embedding_type = 2
configuration.num_heads = 7
# Number of calendar/period features that accompany each time step.
if configuration.period_type == 'minute':
    configuration.period_feature_num = 6
elif configuration.period_type == 'hour':
    configuration.period_feature_num = 4
elif configuration.period_type == 'day':
    configuration.period_feature_num = 4
else:
    # Fail here rather than later with an AttributeError on period_feature_num.
    raise ValueError(
        "Unsupported period_type %r; expected 'minute', 'hour' or 'day'"
        % (configuration.period_type,)
    )


class TimeSeriesEmbedding(nn.Module):
    """Embed a multivariate time series into ``config.embedding_size`` channels.

    The mode is selected by ``config.embedding_type``:

    * ``1`` -- linear query/key/value projections followed by multi-head
      attention (optionally cross-attending to ``key_values_input_ids``).
    * ``2`` -- "feature as token": every one of the ``config.hidden_size``
      input channels goes through its own ``Linear(1, embedding_size //
      hidden_size)`` layer, is multiplied by the raw channel value, and the
      per-channel embeddings are concatenated on the last axis.

    Args:
        config: object exposing ``hidden_size``, ``period_feature_num``,
            ``embedding_size``, ``embedding_type``, ``sequence_length`` and
            ``batch_size``.
    """

    def __init__(self,config):

        super(TimeSeriesEmbedding, self).__init__()
        self.hidden_size = config.hidden_size
        self.period_feature_num = config.period_feature_num
        self.embedding_size = config.embedding_size
        self.embedding_type = config.embedding_type
        self.sequence_length = config.sequence_length
        # Width of each per-channel embedding in the feature-as-token mode.
        self.feature_as_token_each_feature_emb_size = int(config.embedding_size/(config.hidden_size))
        self.query = nn.Linear(self.hidden_size,self.embedding_size)
        self.key = nn.Linear(self.hidden_size,self.embedding_size)
        self.value = nn.Linear(self.hidden_size,self.embedding_size)
        self.multihead_attn = nn.MultiheadAttention(self.embedding_size, num_heads=config.hidden_size,batch_first=True)
        self.batch_size = config.batch_size
        self.feature_as_token_weights = nn.ModuleList([nn.Linear(1, self.feature_as_token_each_feature_emb_size) for _ in range(config.hidden_size)])

    def forward(self,input_ids,key_values_input_ids = None):
        """Compute the embedding of ``input_ids``.

        Args:
            input_ids: tensor indexed as ``[batch, time, channel]`` with
                ``hidden_size`` channels (per the indexing below); cast to float.
            key_values_input_ids: optional tensor used as the key/value source
                in attention mode; when ``None`` the layer self-attends. It is
                ignored in feature-as-token mode.

        Returns:
            The attention output (mode 1) or the concatenated per-channel
            embeddings (mode 2).

        Raises:
            ValueError: if ``embedding_type`` is neither 1 nor 2.
        """
        input_ids = input_ids.float()

        if self.embedding_type == 1:
            if key_values_input_ids is not None:
                # Cast like input_ids so integer inputs do not break the Linear layers.
                kv_source = key_values_input_ids.float()
            else:
                kv_source = input_ids
            query = self.query(input_ids)
            key = self.key(kv_source)
            # The value projection has its own layer; previously self.key was
            # reused here, leaving self.value untrained and unused.
            value = self.value(kv_source)
            attn_output, _ = self.multihead_attn(query, key, value)
            return attn_output

        if self.embedding_type == 2:
            per_feature_embeds = []
            for i in range(self.hidden_size):
                # One channel at a time, kept 3-D so Linear(1, k) applies per step.
                channel = input_ids[:, :, i].unsqueeze(-1)
                projected = self.feature_as_token_weights[i](channel)
                # Scale the learned embedding by the raw value (broadcast on last axis).
                per_feature_embeds.append(projected * channel)
            return torch.cat(per_feature_embeds, dim=-1)

        raise ValueError(
            "Unsupported embedding_type %r; expected 1 or 2" % (self.embedding_type,)
        )


class DeepReservoirNet(nn.Module):
    """A single leaky reservoir with frozen random input and recurrent weights.

    ``W_in`` (``config.sequence_length -> reservoir_size``) and ``W_res``
    (``reservoir_size -> reservoir_size``) are created without bias and with
    ``requires_grad = False``. ``W_res`` is sparsified and divided by its
    spectral radius at construction time.

    NOTE(review): ``spectral_radius`` is stored on the instance but is never
    used to rescale ``W_res`` (which ends up with unit spectral radius), and
    the state is divided by the *pre-normalisation* radius on every step.
    Both look unintended; confirm before relying on ``spectral_radius``.

    Args:
        config: object exposing ``sequence_length``, ``re_output_size`` and
            ``num_attention_heads``.
        reservoir_size: number of reservoir units.
        spectral_radius: stored only (see note above).
        leaky: leak rate used in the state update.
        sparsity: fraction of ``W_res`` entries set to zero.
        activation_function: ``"tanh"``, ``"relu"`` or ``"sigmoid"``; any
            other value falls back to tanh.
        return_entropy: when true, ``forward`` also reports a Shannon entropy
            estimate of the final reservoir state.
    """

    def __init__(self, config, reservoir_size=1000, spectral_radius=0.9, leaky=0.3, sparsity=0.5, activation_function="tanh",return_entropy = False):
        super(DeepReservoirNet, self).__init__()

        self.input_size = config.sequence_length
        self.reservoir_size = reservoir_size
        self.output_size = config.re_output_size
        self.spectral_radius = spectral_radius
        self.leaky = leaky

        self.W_in = nn.Linear(self.input_size, reservoir_size, bias=False).float()
        self.W_in.weight.requires_grad = False
        self.W_res = nn.Linear(reservoir_size, reservoir_size, bias=False).float()
        self.W_res.weight.requires_grad = False
        # Plain tensor (not a buffer): it does not follow .to(device) and is not
        # part of the state dict. forward() takes the state explicitly instead.
        self.res_state = torch.zeros(1, reservoir_size).float()
        self.layer_norm = nn.LayerNorm(reservoir_size)
        self.act = nn.Tanh()
        self.return_entropy = return_entropy

        # Re-draws W_res, sparsifies it and normalises it; keeps the original radius.
        self.W_res_norm = self.compute_spectral_radius(sparsity)
        self.self_attention = nn.MultiheadAttention(self.output_size, config.num_attention_heads, dropout=0.2)
        if activation_function == "tanh":
            self.act = nn.Tanh()
        elif activation_function == "relu":
            self.act = nn.ReLU()
        elif activation_function == "sigmoid":
            self.act = nn.Sigmoid()

    def create_ws_reservoir(self, N, sparsity, target_radius):
        """Create Watts-Strogatz small-world reservoir weights using NetworkX.

        Args:
            N: number of nodes (reservoir units).
            sparsity: target fraction of missing connections.
            target_radius: spectral radius the returned matrix is scaled to.

        Returns:
            A float32 tensor of shape ``(N, N)``.
        """
        # 1. Derive the average degree K that realises the target sparsity.
        avg_degree = max(2, int((1 - sparsity) * (N - 1)))

        # 2. Generate the small-world graph.
        ws_graph = nx.watts_strogatz_graph(
            n=N,
            k=avg_degree,
            p=0.1  # rewiring probability
        )

        # 3. Convert to an adjacency matrix.
        adj_matrix = nx.to_numpy_array(ws_graph)

        # 4. Attach random weights to the existing edges.
        weight_matrix = adj_matrix * np.random.normal(0, 1, size=(N, N))

        # 5. Rescale to the target spectral radius.
        eigvals_ = eigvals(weight_matrix)
        current_radius = np.max(np.abs(eigvals_))
        weight_matrix *= target_radius / current_radius

        # Convert to a PyTorch tensor.
        return torch.tensor(weight_matrix, dtype=torch.float32, requires_grad=False)

    def compute_spectral_radius(self, sparsity=0.5):
        """Re-initialise ``W_res`` and normalise it by its spectral radius.

        ``W_res`` is redrawn from a standard normal, a ``sparsity`` fraction of
        its entries is zeroed, and the result is divided by the largest
        eigenvalue magnitude.

        Returns:
            The spectral radius measured *before* normalisation (0-dim tensor).
        """
        with torch.no_grad():
            self.W_res.weight.data = torch.randn(self.reservoir_size, self.reservoir_size)
            # set a fraction of the entries to zero
            num_zeros = int(sparsity * self.reservoir_size ** 2)
            idxs = torch.randperm(self.reservoir_size ** 2)[:num_zeros]
            self.W_res.weight.data.view(-1)[idxs] = 0

            eigenvals = torch.linalg.eigvals(self.W_res.weight)
            radius = torch.max(torch.abs(eigenvals))
            self.W_res.weight.data /= radius
        return radius

    @staticmethod
    def entropy_of_values(value_tensor, num_bins=100):
        """Shannon entropy (in bits) of a histogram of ``value_tensor``.

        Values are bucketed into ``num_bins`` equal-width bins over ``[0, 1]``;
        anything outside that range is clamped into the first or last bin.

        Args:
            value_tensor: 1-D tensor of values (required by ``torch.bincount``).
            num_bins: number of histogram bins.

        Returns:
            A 0-dim tensor holding the entropy.
        """
        # Build the bin edges on the input's device so CUDA inputs work too.
        bins = torch.linspace(0, 1, steps=num_bins+1, device=value_tensor.device)

        indices = torch.bucketize(value_tensor, bins, right=True) - 1
        indices = torch.clamp(indices, 0, num_bins-1)

        counts = torch.bincount(indices, minlength=num_bins).float()
        probabilities = counts / torch.sum(counts)

        # eps keeps log2 finite for empty bins (their contribution is 0 * finite).
        eps = 1e-10
        entropy = -torch.sum(probabilities * torch.log2(probabilities + eps))

        return entropy

    def forward(self, input_data, res_state):
        """Drive the reservoir with ``input_data`` and collect every state.

        The first axis of ``input_data`` is iterated as the time axis: each
        slice updates the carried ``res_state`` and the new state is recorded.

        Args:
            input_data: tensor indexed ``[step, sequence_length, channel]``;
                permuted internally so ``W_in`` acts on the sequence axis.
            res_state: initial reservoir state, broadcastable to
                ``(channel, reservoir_size)``.

        Returns:
            dict with ``'Output'`` (stacked states, one per step), ``"State"``
            (the final state) and, when ``return_entropy`` is set,
            ``"Shannon_entropy"`` (mean entropy over channels of the last state).
        """
        outputs = []

        num_steps = input_data.shape[0]
        hidden_size = input_data.shape[2]
        input_data = input_data.permute(0, 2, 1)

        # Branch on the activation module itself. The previous code compared the
        # module to the strings "relu"/"sigmoid", which is always False, so the
        # dedicated ReLU and sigmoid updates were never executed.
        use_relu = isinstance(self.act, nn.ReLU)
        use_sigmoid = isinstance(self.act, nn.Sigmoid)

        for t in range(num_steps):
            input_proj = self.W_in(input_data[t].float())
            res_proj = self.W_res(res_state)
            pre_activation = input_proj + res_proj

            if use_relu:
                # Normalise first, then bound the unbounded ReLU output.
                update = torch.clip(self.act(self.layer_norm(pre_activation)), min=-1, max=1)
            elif use_sigmoid:
                # Rescale sigmoid from (0, 1) to (-1, 1).
                update = 2 * self.act(pre_activation) - 1
            else:
                update = self.act(pre_activation)

            res_state = (1 - self.leaky) * res_state + self.leaky * update
            # Normalize reservoir state
            res_state = res_state / self.W_res_norm

            outputs.append(res_state.squeeze(0))

        final_output = torch.stack(outputs, dim=0)

        result = {'Output': final_output, "State": res_state}
        if self.return_entropy:
            # squeeze(0) above drops the channel axis when there is a single
            # channel; restore it so per-channel indexing is always valid.
            last_state = final_output[-1].detach().reshape(hidden_size, -1)
            entropy_list = [
                float(self.entropy_of_values(last_state[i]))
                for i in range(hidden_size)
            ]
            result["Shannon_entropy"] = statistics.mean(entropy_list)
        return result



class ReservoirTTimeSeries(nn.Module):
    """Ensemble of ``DeepReservoirNet`` reservoirs applied to embedded inputs.

    One reservoir is built per index ``i < config.num_reservoirs`` from the
    ``i``-th entries of ``config.reservoir_size``, ``config.spectral_radius``,
    ``config.leaky``, ``config.sparsity`` and ``config.activation_function``.

    Args:
        config: configuration object providing the attributes above plus
            ``num_labels`` and ``hidden_size``.
        return_entropy: forwarded to every reservoir; when true each reservoir
            reports a Shannon entropy that is printed during ``forward``.
    """

    def __init__(self, config,return_entropy = True):
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config

        self.layer_norm = nn.LayerNorm(config.hidden_size)

        self.reservoirs=nn.ModuleList()
        self.id_train = None
        self.id_test = None
        self.reservoir_state = None
        self.state_ids = None
        self.return_entropy = return_entropy
        for i in range(config.num_reservoirs):
            reservoir = DeepReservoirNet(config=config,
                                         reservoir_size=config.reservoir_size[i],
                                         spectral_radius=config.spectral_radius[i],
                                         leaky=config.leaky[i],
                                         sparsity=config.sparsity[i],
                                         activation_function = config.activation_function[i],
                                         return_entropy = self.return_entropy)

            self.reservoirs.append(reservoir)

    @staticmethod
    def _tail(sequence, count):
        """Return the last ``count`` items of ``sequence`` (``None`` passes through)."""
        if sequence is None:
            return None
        # An explicit start index avoids the `[-0:]` pitfall, which would
        # return the whole sequence when count is 0.
        return sequence[max(len(sequence) - count, 0):]

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        x_marks: Optional[torch.Tensor] = None,
        y_marks: Optional[torch.Tensor] = None,
        reservoir_ids: Optional[torch.Tensor] = None,
        state_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dataset_type = None,
        train_dataset = None,
        eval_dataset = None,
        id = "id_train",
        return_entropy = False,
    ) -> dict:
        r"""
        Run every reservoir over the samples and return their states per sample.

        For ``dataset_type`` ``"eval_dataset"`` / ``"test_dataset"`` the
        embeddings of the earlier split(s) are prepended, so the reservoir
        states are warmed up on the preceding data; only the trailing
        ``len(inputs_embeds)`` samples are returned.

        Args:
            inputs_embeds: embedded samples, concatenated and padded along
                axis 0 (assumed 3-D -- TODO confirm with the caller).
            x_marks, y_marks, labels_ids: passed through to the result
                (marks are tail-sliced for eval/test; labels are not sliced).
            reservoir_ids: reservoir inputs; when ``None`` they are
                ``inputs_embeds`` shifted by one sample with a zero first sample.
            dataset_type: ``None``, ``"eval_dataset"`` or ``"test_dataset"``.
            train_dataset: mapping with ``"inputs_embeds"``; required for
                eval/test.
            eval_dataset: mapping with ``"inputs_embeds"``; required for test.
            input_ids, state_ids, attention_mask, token_type_ids, position_ids,
            head_mask, output_attentions, output_hidden_states, return_dict,
            id, return_entropy: accepted for interface compatibility and
                ignored (reservoir states always start from zeros).

        Returns:
            dict with keys ``"inputs_embeds"``, ``"reservoir_outputs"`` (one
            list per sample holding one tensor per reservoir), ``"x_marks"``,
            ``"y_marks"`` and ``"labels_ids"``.

        Raises:
            ValueError: if ``dataset_type`` is not one of the supported values.
        """
        if dataset_type not in (None, "eval_dataset", "test_dataset"):
            raise ValueError(
                "Unsupported dataset_type %r; expected None, 'eval_dataset' or 'test_dataset'"
                % (dataset_type,)
            )

        sample_size = None
        if dataset_type == "eval_dataset":
            sample_size = inputs_embeds.shape[0]
            inputs_embeds = torch.cat((train_dataset["inputs_embeds"], inputs_embeds), dim=0)
        elif dataset_type == "test_dataset":
            sample_size = inputs_embeds.shape[0]
            inputs_embeds = torch.cat((train_dataset["inputs_embeds"],
                                       eval_dataset["inputs_embeds"],
                                       inputs_embeds), dim=0)

        if reservoir_ids is None:
            # Pad one zero sample in front of axis 0 (for a 3-D tensor) and drop
            # the last one, so sample t is driven by the embedding of sample t-1.
            padded_tensor = F.pad(inputs_embeds, (0, 0, 0, 0, 1, 0))
            reservoir_ids = padded_tensor[:-1]

        # Every reservoir starts from an all-zero state.
        state_ids = [torch.zeros(self.config.hidden_size, self.config.reservoir_size[i]).float() for i in range(self.config.num_reservoirs)]
        reservoir_outputs=[]
        shannon_entropy = []
        for i, reservoir in tqdm(enumerate(self.reservoirs)):
            if self.return_entropy:
                print("Calculating Entropy......")

            reservoir_output = reservoir(reservoir_ids.float(), state_ids[i].to(inputs_embeds.device))
            state_ids[i] = reservoir_output['State']
            reservoir_outputs.append(reservoir_output['Output'])

            # The entropy key only exists when the reservoirs were built with
            # return_entropy=True; do not assume it.
            entropy = reservoir_output.get('Shannon_entropy')
            if entropy is not None:
                shannon_entropy.append(entropy)
                print("Shannon entropy for ESN #"+str(i)+":",entropy)

        if shannon_entropy:
            shan_entropy = statistics.mean(shannon_entropy)
            print("General Shannon Entropy is",shan_entropy)

        # Transpose from per-reservoir (each iterated over samples) to
        # per-sample lists containing one tensor per reservoir.
        reservoir_outputs = [list(per_sample) for per_sample in zip(*reservoir_outputs)]

        if dataset_type is None:
            return {"inputs_embeds":inputs_embeds,
                    "reservoir_outputs":reservoir_outputs,
                    "x_marks":x_marks,
                    "y_marks":y_marks,
                    "labels_ids":labels_ids}

        # eval / test: keep only the samples that belong to the requested split.
        return {"inputs_embeds":self._tail(inputs_embeds, sample_size),
                "reservoir_outputs":self._tail(reservoir_outputs, sample_size),
                "x_marks":self._tail(x_marks, sample_size),
                "y_marks":self._tail(y_marks, sample_size),
                "labels_ids":labels_ids}



class Reservoir_fl_model(nn.Module):
    """Forecasting head that fuses embedded inputs with reservoir outputs.

    ``forward`` embeds the input window with ``TimeSeriesEmbedding``, projects
    every reservoir output through its own linear read-out, merges the
    reservoirs with softmax weights produced by ``score_mlp``, lets the embedded
    input attend to the merged reservoir signal through ``cross_attn_layers``
    cross-attention layers (residual + LayerNorm), and maps the result to the
    prediction horizon with ``feed_forward``.

    Several sub-modules built in ``__init__`` (``reservoir``, ``ReservoirModel``,
    ``patchtst``, ``norm1``, ``norm2``, ``context``, ``self_attention_*``,
    ``self_attnetion_layer_norm_*``, ``decoder``) are not used by ``forward``.
    They are still constructed, in the original order, so parameter names and
    the order in which modules are initialised stay unchanged.
    """

    # Number of cyclic features ``Time_feature_encoder`` emits per period type
    # (one sin/cos pair per encoded calendar component).
    _PERIOD_FEATURE_NUM = {'minute': 6, 'hour': 4, 'day': 4}

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Args:
            config: configuration object; this constructor reads
                ``num_labels``, ``hidden_size``, ``re_output_size``,
                ``sequence_length``, ``batch_size``, ``num_reservoirs``,
                ``embedding_size``, ``pred_len``, ``period_type``,
                ``reservoir_size`` (indexable per reservoir) and ``output_size``.
        """
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config
        self.hidden_size = config.hidden_size
        self.re_output_size = config.re_output_size
        self.sequence_length = config.sequence_length
        self.batch_size = config.batch_size
        self.reservoir = ReservoirTTimeSeries(config)
        self.num_res = config.num_reservoirs

        # embedding_size truncated to a multiple of hidden_size; this is the
        # feature width used by the cross-attention queries and feed_forward.
        feature_dim = int(self.config.embedding_size / self.config.hidden_size) * self.config.hidden_size

        patchTST_config = PatchTSTConfig(prediction_length=720,
                                         num_input_channels=feature_dim,
                                         context_length=336,
                                         num_attention_heads=16,
                                         patch_length=16,
                                         patch_stride=8,
                                         dropout=0.2,
                                         d_model=128,
                                         ffn_dim=256,
                                         head_dropout=0.2,
                                         scaling="std",
                                         pre_norm=True,
                                         norm_type="layernorm",
                                         channel_attention=False,
                                         random_mask_ratio=0.4,
                                         )
        self.patchtst = PatchTSTForPrediction(patchTST_config)
        self.norm1 = nn.LayerNorm(feature_dim)
        self.norm2 = nn.LayerNorm(self.config.pred_len)
        self.cross_attn_layers = 3
        self.period_type = config.period_type

        # Unknown period types make the encoder emit zero features, so the
        # attribute is always defined (previously it was missing for them).
        self.period_feature_num = self._PERIOD_FEATURE_NUM.get(self.period_type, 0)

        self.feed_forward = nn.Sequential(
            # Maps the time axis from the look-back length to the horizon.
            self.SequenceAdjuster(
                input_seq_len=self.sequence_length,
                output_seq_len=self.config.pred_len,
                feature_dim=feature_dim
            ),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            # Final width is hidden_size (original note: "should be 7" —
            # TODO confirm against the configuration in use).
            nn.Linear(feature_dim, self.config.hidden_size),
        )
        self.dropout = nn.Dropout(0.3)
        self.context = nn.Linear(self.config.embedding_size, self.config.embedding_size)
        # Scores one pooled vector per reservoir; softmax over the scores
        # yields the mixing weights in forward().
        self.score_mlp = nn.Sequential(
            nn.Linear(self.config.embedding_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        self.EmbeddingModel = TimeSeriesEmbedding(self.config)
        self.ReservoirModel = ReservoirTTimeSeries(self.config)
        self.self_attention_0 = nn.MultiheadAttention(embed_dim=self.config.embedding_size, num_heads=config.hidden_size, batch_first=True, dropout=0.2)
        self.self_attention_1 = nn.MultiheadAttention(embed_dim=self.config.embedding_size, num_heads=config.hidden_size, batch_first=True, dropout=0.2)
        self.self_attnetion_layer_norm_0 = nn.LayerNorm(self.config.embedding_size)
        self.self_attnetion_layer_norm_1 = nn.LayerNorm(self.config.embedding_size)
        self.cross_attns = nn.ModuleList()
        self.cross_attn_norms = nn.ModuleList()
        self.time_feature_encoder = self.Time_feature_encoder(self.period_type)
        for _ in range(self.cross_attn_layers):
            self.cross_attns.append(
                nn.MultiheadAttention(
                    embed_dim=feature_dim,
                    kdim=self.config.embedding_size,
                    vdim=self.config.embedding_size,
                    num_heads=config.hidden_size,
                    batch_first=True,
                    dropout=0.2
                )
            )
            self.cross_attn_norms.append(nn.LayerNorm(feature_dim))
        self.decoder = nn.Linear(feature_dim, self.config.hidden_size)
        # One linear read-out per reservoir.
        self.W_outputs = nn.ModuleList()
        for i in range(self.num_res):
            self.W_outputs.append(nn.Linear(config.reservoir_size[i], self.config.output_size).float())

    class SequenceAdjuster(nn.Module):
        """Linear map over the last axis followed by a swap of axes 1 and 2."""

        def __init__(self, input_seq_len, output_seq_len, feature_dim):
            """``feature_dim`` is accepted for interface compatibility but unused."""
            super().__init__()
            self.seq_adjuster = nn.Linear(input_seq_len, output_seq_len)

        def forward(self, x):
            """Map ``[batch, features, seq_in]`` to ``[batch, seq_out, features]``."""
            x = self.seq_adjuster(x)
            return x.transpose(1, 2)

    class Time_feature_encoder(nn.Module):
        """Cyclic (sin/cos) encoder for calendar features."""

        def __init__(self, period_type='hour'):
            """Store the period type that selects which components are encoded.

            Args:
                period_type: ``'minute'`` encodes minute, hour and weekday;
                    ``'hour'`` encodes hour and weekday; ``'day'`` encodes
                    weekday and month. Any other value yields no features.
            """
            super().__init__()
            self.period_type = period_type

        def forward(self, time_features):
            """Encode calendar columns as sin/cos pairs.

            Args:
                time_features: 3-D tensor whose last axis is read as
                    ``[minute, hour, dayofweek, day, month]``.

            Returns:
                Tensor of shape ``(batch, seq_len, F)`` where ``F`` is 6 for
                ``'minute'``, 4 for ``'hour'`` / ``'day'`` and 0 otherwise.
            """
            batch_size, seq_len, _ = time_features.shape
            # reshape instead of view: view raises on non-contiguous inputs
            # (e.g. a batch-strided slice), reshape copies only when needed.
            flat_features = time_features.reshape(-1, 5)
            minutes = flat_features[:, 0].float()
            hours = flat_features[:, 1].float()
            dayofweek = flat_features[:, 2].float()
            month = flat_features[:, 4].float()

            # (values, period length) pairs, in output order.
            if self.period_type == 'minute':
                components = [(minutes, 60), (hours, 24), (dayofweek, 7)]
            elif self.period_type == "hour":
                components = [(hours, 24), (dayofweek, 7)]
            elif self.period_type == "day":
                components = [(dayofweek, 7), (month, 12)]
            else:
                components = []

            features = []
            for values, period in components:
                angle = 2 * torch.pi * values / period
                features.extend([torch.sin(angle), torch.cos(angle)])

            if not features:
                return torch.zeros(batch_size, seq_len, 0, device=time_features.device)

            cyclic_flat = torch.stack(features, dim=1)
            return cyclic_flat.view(batch_size, seq_len, -1)

    def forward(self,
            inputs_embeds: Optional[torch.Tensor] = None,
            reservoir_outputs: list = None,
            x_marks: Optional[torch.Tensor] = None,
            y_marks: Optional[torch.Tensor] = None,
            labels_ids: Optional[torch.Tensor] = None,):
        """Predict the horizon and, when labels are given, compute the losses.

        Args:
            inputs_embeds: input window passed to ``EmbeddingModel``.
            reservoir_outputs: iterable paired element-wise with ``W_outputs``;
                presumably one tensor per reservoir — TODO confirm with the
                data collation used by the caller.
            x_marks: optional calendar features of the input window. They are
                encoded but the encoding is currently not fed into the model.
            y_marks: accepted for interface compatibility; unused.
            labels_ids: optional targets with the same shape as the prediction.

        Returns:
            ``{"prediction_outputs"}`` without labels, otherwise
            ``{"loss", "mae_loss", "mse_loss", "prediction_outputs"}`` where
            ``loss`` is the Huber loss (delta = 1.0).
        """
        if x_marks is not None:
            # Kept for parity with the original pipeline; the result is not
            # consumed below (the concatenation with the inputs is disabled).
            x_marks = self.time_feature_encoder(x_marks)
        attn_output = self.EmbeddingModel(inputs_embeds)

        # Project every reservoir output with its own read-out layer.
        reservoir_outputs_fl = [W_out(output) for output, W_out in zip(reservoir_outputs, self.W_outputs)]

        # Stack on a new axis 2, then move that axis to position 1 so the
        # layout unpacked below is (B, N, T, D).
        reservoir_outputs = torch.stack(reservoir_outputs_fl, dim=2)
        reservoir_outputs = reservoir_outputs.permute(0, 2, 1, 3)
        print("reservoir_outputs before combine", reservoir_outputs.shape)
        stacked = reservoir_outputs
        B, N, T, D = stacked.shape
        # Softmax-weighted sum over the N axis; one scalar score per entry,
        # computed from its mean over the T axis.
        pooled = stacked.mean(dim=2)
        scores = self.score_mlp(pooled.view(B * N, D)).view(B, N)
        weights = F.softmax(scores, dim=1)
        weights = weights.unsqueeze(-1).unsqueeze(-1)
        reservoir_outputs = (stacked * weights).sum(dim=1)
        print("reservoir_outputs after combine", reservoir_outputs.shape)
        print("attn_output.shape",attn_output.shape)
        print("reservoir_outputs.shape",reservoir_outputs.shape)

        for i in range(self.cross_attn_layers):
            # Embedded input queries the merged reservoir signal.
            attn_layer_output, _ = self.cross_attns[i](
                query=attn_output,
                key=reservoir_outputs,
                value=reservoir_outputs
            )
            # Residual connection followed by LayerNorm.
            attn_output = attn_output + self.dropout(attn_layer_output)
            attn_output = self.cross_attn_norms[i](attn_output)

        # feed_forward expects the time axis last (see SequenceAdjuster).
        prediction = self.feed_forward(attn_output.permute(0, 2, 1))

        if labels_ids is None:
            return {"prediction_outputs": prediction.float()}

        labels_ids = labels_ids.float()

        huber_loss_fn = torch.nn.HuberLoss(reduction='mean', delta=1.0)
        huber_loss = huber_loss_fn(prediction, labels_ids)

        # NOTE(review): the "adaptive delta" below is only reported. The
        # original assigned it to the freshly created loss object after the
        # loss had already been computed, so it never influenced training;
        # that dead assignment was removed. Decide whether the delta should
        # actually be carried over to the next step.
        with torch.no_grad():
            abs_errors = torch.abs(prediction.detach() - labels_ids)
            median_abs_error = torch.median(abs_errors)
            adaptive_delta = torch.clamp(median_abs_error * 0.5, min=0.5, max=5.0)

        # MAE / MSE are reported as additional metrics.
        mae_loss = F.l1_loss(prediction, labels_ids)
        mse_loss = F.mse_loss(prediction, labels_ids)

        # wandb.log fails when no run has been initialised; only log when a
        # run is active so the model also works outside a wandb session.
        if wandb.run is not None:
            wandb.log({
                "huber_loss": huber_loss,
                "mae_loss": mae_loss,
                "mse_loss": mse_loss,
                "huber_delta": adaptive_delta.item(),
                "median_abs_error": median_abs_error.item()
            })

        # Huber loss is the optimisation target.
        loss = huber_loss

        return {"loss": loss.float(),
                "mae_loss": mae_loss.float(),
                "mse_loss": mse_loss.float(),
                "prediction_outputs": prediction.float(),
                }
def extract_inputs_and_labels(dataset):
    """Materialise a whole dataset as four concatenated tensors.

    The dataset is read in order (no shuffling) in batches of 32 and the
    per-batch tensors of each field are concatenated along dimension 0.

    Returns:
        dict with the keys ``inputs_embeds``, ``x_marks``, ``labels_ids`` and
        ``y_marks``.
    """
    field_names = ("inputs_embeds", "x_marks", "labels_ids", "y_marks")
    collected = {name: [] for name in field_names}

    batches = DataLoader(dataset, batch_size=32, shuffle=False)
    for batch in tqdm(batches, desc="Training Batches"):
        for name in field_names:
            collected[name].append(batch[name])

    return {name: torch.cat(chunks, dim=0) for name, chunks in collected.items()}


#from time_data_normalize import Dataset_ETT_hour
import numpy as np
# prepare data for lstm
from sklearn.preprocessing import StandardScaler
from pandas import read_csv
from pandas import DataFrame
import random
from sklearn.model_selection import train_test_split
from pandas import concat
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset

if __name__ == "__main__":
    # ---- Load the raw series -------------------------------------------
    dataset = read_csv('ETTh1.csv')
    print(dataset.columns)
    dataset['date'] = pd.to_datetime(dataset['date'])

    # Drop incomplete rows BEFORE deriving the calendar features so that
    # time_features stays row-aligned with the value matrix X. (Previously the
    # features were extracted first, so any dropped row shifted the two apart.)
    dataset = dataset.dropna()

    # Column order [minute, hour, dayofweek, day, month] is what
    # Time_feature_encoder indexes into.
    stamp = dataset['date'].dt
    time_features = np.column_stack(
        [stamp.minute, stamp.hour, stamp.dayofweek, stamp.day, stamp.month]
    )
    dataset = dataset.drop(['date'], axis=1)

    X = dataset.values

    # NOTE(review): the scaler is fitted on ALL rows, i.e. including the ones
    # that later become the validation / test windows. Confirm this is
    # intended; fitting on the training span only avoids look-ahead leakage.
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    from tqdm.auto import tqdm

    # ---- Sliding windows ------------------------------------------------
    def create_sequences(data, time_features, seq_length, pred_length):
        """Cut ``data`` / ``time_features`` into (input, target) windows.

        Window ``i`` uses rows ``[i, i + seq_length)`` as input and the
        following ``pred_length`` rows as target.

        Returns:
            Tensors ``(inputs, input_time_features, targets,
            target_time_features)``.
        """
        sequences = []
        seq_x_time = []
        targets = []
        seq_y_time = []
        for i in tqdm(range(len(data) - seq_length - pred_length + 1)):
            split = i + seq_length
            sequences.append(data[i:split])
            seq_x_time.append(time_features[i:split])
            targets.append(data[split:split + pred_length])
            seq_y_time.append(time_features[split:split + pred_length])
        # Stack with NumPy first: building a tensor straight from a list of
        # ndarrays is very slow.
        return (torch.tensor(np.array(sequences)),
                torch.tensor(np.array(seq_x_time)),
                torch.tensor(np.array(targets)),
                torch.tensor(np.array(seq_y_time)))

    X, x_marks, y, y_marks = create_sequences(
        X,
        time_features,
        seq_length=configuration.sequence_length,
        pred_length=configuration.pred_len,
    )

    # ---- Chronological 70 / 10 / 20 split in blocks of `batch` windows ---
    batch = 100
    indices = np.arange(len(X))
    barrier = int(len(indices)/batch)*batch
    indices = indices[0:barrier]  # drop the incomplete trailing block
    # Number of blocks by which the val / test ranges reach back into the
    # preceding split.
    soft_border = int((configuration.sequence_length/batch))+8

    indices = [indices[i:i+batch] for i in range(0, len(indices), batch)]

    border1 = int(len(indices)*0.7)
    border2 = border1+int(len(indices)*0.1)
    border3 = border2+int(len(indices)*0.2)

    train_ind = indices[0:border1]
    # Clamp at 0 so a small dataset cannot turn the start into a negative
    # (wrap-around) slice index.
    val_ind = indices[max(border1-soft_border, 0): border2]
    test_ind = indices[max(border2-soft_border, 0): border3]

    def _select(source, index_blocks):
        """Return ``source[i]`` for every index ``i`` in the given blocks."""
        return [source[item] for sublist in index_blocks for item in sublist]

    X_train = _select(X, train_ind)
    x_marks_train = _select(x_marks, train_ind)
    y_train = _select(y, train_ind)
    y_marks_train = _select(y_marks, train_ind)

    X_val = _select(X, val_ind)
    x_marks_val = _select(x_marks, val_ind)
    y_val = _select(y, val_ind)
    y_marks_val = _select(y_marks, val_ind)

    X_test = _select(X, test_ind)
    x_marks_test = _select(x_marks, test_ind)
    y_test = _select(y, test_ind)
    y_marks_test = _select(y_marks, test_ind)

#train_indices, test_indices =train_test_split(indices,  test_size=0.2, shuffle=False)
#indices = [item for sublist in indices for item in sublist]

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Map-style dataset that serves float32 tensors per sample.

    Args:
        tokenized_inputs: indexable collection of input windows.
        x_marks: optional indexable collection of input time features.
        y_marks: optional indexable collection of target time features.
        labels: optional indexable collection of targets.
        pos: stored on the instance; not used by this class.
    """

    def __init__(self, tokenized_inputs,  x_marks = None, y_marks = None, labels=None, pos=None):
        self.tokenized_inputs = tokenized_inputs
        self.x_marks = x_marks
        self.y_marks = y_marks
        self.labels = labels
        self.pos = pos
        # Placeholders kept for compatibility; nothing in this class fills them.
        self.id_list = None
        self.re = None

    def __len__(self):
        """Number of samples (length of ``tokenized_inputs``)."""
        return len(self.tokenized_inputs)

    @staticmethod
    def _to_float_tensor(value):
        """Return an independent float32 tensor copy of ``value``.

        Tensors are detached and cloned instead of being passed through
        ``torch.tensor`` again, which avoids the copy-construct warning while
        keeping the copy semantics.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone().float()
        return torch.tensor(value).float()

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of float32 tensors.

        Keys appear in the order ``inputs_embeds``, ``x_marks``,
        ``labels_ids``, ``y_marks``; optional fields that were not supplied
        (left as ``None``) are omitted instead of raising ``TypeError``.
        """
        item = {"inputs_embeds": self._to_float_tensor(self.tokenized_inputs[idx])}
        if self.x_marks is not None:
            item["x_marks"] = self._to_float_tensor(self.x_marks[idx])
        if self.labels is not None:
            item["labels_ids"] = self._to_float_tensor(self.labels[idx])
        if self.y_marks is not None:
            item["y_marks"] = self._to_float_tensor(self.y_marks[idx])
        return item


# Assumes the train/val/test lists X_*, x_marks_*, y_* and y_marks_* have already been built.


if __name__ == "__main__":
    # Wrap the split lists; argument order is (inputs, x_marks, y_marks, labels).
    train_dataset = CustomDataset(X_train, x_marks_train, y_marks_train, y_train)
    val_dataset = CustomDataset(X_val, x_marks_val, y_marks_val, y_val)
    test_dataset = CustomDataset(X_test, x_marks_test, y_marks_test, y_test)

    # The construction / extraction sequence used to appear twice in a row,
    # which built the reservoir twice and iterated both datasets twice. It is
    # performed once here.
    preprocess = ReservoirTTimeSeries(configuration)
    train_dataset_dic = extract_inputs_and_labels(train_dataset)
    # NOTE(review): despite its name, test_dataset_dic is built from the
    # VALIDATION split and is used as the evaluation set below; test_dataset
    # itself is not consumed in this block.
    test_dataset_dic = extract_inputs_and_labels(val_dataset)

    # Rebinds the name `Dataset` to the Hugging Face class from here on.
    from datasets import Dataset

    # Run the reservoir preprocessing and wrap its output dict as a dataset
    # that yields torch tensors.
    train_dataset_fl = Dataset.from_dict(preprocess(inputs_embeds = train_dataset_dic["inputs_embeds"],
                                                        x_marks = train_dataset_dic["x_marks"],
                                                        y_marks = train_dataset_dic["y_marks"],
                                                        labels_ids = train_dataset_dic["labels_ids"]))
    train_dataset_fl.set_format(type='torch')
    test_dataset_fl = Dataset.from_dict(preprocess(inputs_embeds = test_dataset_dic["inputs_embeds"],
                                                        x_marks = test_dataset_dic["x_marks"],
                                                        y_marks = test_dataset_dic["y_marks"],
                                                        labels_ids = test_dataset_dic["labels_ids"],
                                                        dataset_type = "eval_dataset",
                                                        train_dataset = train_dataset_dic))
    test_dataset_fl.set_format(type='torch')

#embedding_model = TimeSeriesEmbedding(configuration)
#reservoir_model = ReservoirTTimeSeries(configuration)
#fl_model = Reservoir_fl_model(configuration)
#dataloader = DataLoader(train_dataset,batch_size=64,shuffle = False)

#for batch in dataloader:
#    inputs_embeds = batch["inputs_embeds"]
#    label_ids = batch["labels_ids"]
#    inputs_embeds = embedding_model(inputs_embeds)
#    #print(inputs_embeds.shape)
#    #reservoir_output,reservoir_state = reservoir_model(inputs_embeds = inputs_embeds)
#    result = fl_model(inputs_embeds = inputs_embeds)
#    break

import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error

def compute_metrics(eval_pred):
    """Compute MSE and MAE for a ``(predictions, labels)`` pair.

    When ``predictions`` is a tuple, its last element is used and both
    predictions and labels are first converted to NumPy arrays. Anything that
    is not already a tensor is then turned into one.

    Returns:
        dict with the keys ``"MSE"`` and ``"MAE"`` (tensor values).
    """
    preds, targets = eval_pred

    if isinstance(preds, tuple):
        preds = np.array(preds[-1])
        targets = np.array(targets)

    preds = preds if isinstance(preds, torch.Tensor) else torch.tensor(preds)
    targets = targets if isinstance(targets, torch.Tensor) else torch.tensor(targets)

    mae_value = F.l1_loss(preds, targets)
    mse_value = F.mse_loss(preds, targets)
    return {"MSE": mse_value, "MAE": mae_value}

training_args = TrainingArguments(
    # --- optimisation ---------------------------------------------------
    learning_rate=0.002,
    num_train_epochs=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # --- evaluate, checkpoint and log once per epoch ----------------------
    do_eval=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    # --- best-model selection: lowest eval_loss wins ----------------------
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # --- model inputs -----------------------------------------------------
    label_names=["labels_ids"],
    # --- output locations and reporting -----------------------------------
    output_dir="./checkpoint/patchtst/ETTh1/pretrain/last_hope_output/",
    overwrite_output_dir=True,
    logging_dir="./checkpoint/patchtst/ETTh1/pretrain/logs/",
    report_to="wandb",
)

# Early-stopping settings: an improvement must exceed the threshold to count,
# and 20 evaluations without one stop the run.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_threshold=0.001,
    early_stopping_patience=20,
)
#print(train_dataset[0])
#print(train_dataset[0].keys())



class ReservoirTrainer(Trainer):
    """Trainer whose dataloaders are plain ``DataLoader`` objects.

    The loaders are returned exactly as built here, so whatever additional
    preparation the base ``Trainer`` performs in its own dataloader methods is
    skipped. Batch sizes are taken from ``self.args`` instead of a hard-coded
    constant.
    """

    def get_train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over ``self.train_dataset``."""
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.args.per_device_train_batch_size,
        )

    def get_eval_dataloader(self, eval_dataset=None) -> DataLoader:
        """Return an ordered loader over ``eval_dataset`` (default: ``self.eval_dataset``).

        Evaluation data is not shuffled, so predictions come back in dataset
        order.
        """
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        return DataLoader(
            eval_dataset,
            shuffle=False,
            batch_size=self.args.per_device_eval_batch_size,
        )

    def get_test_dataloader(self, test_dataset=None) -> DataLoader:
        """Return an ordered loader over ``test_dataset``.

        Falls back to a ``test_dataset`` attribute if one has been set on the
        trainer.

        Raises:
            ValueError: if no test dataset is passed and none is attached.
        """
        if test_dataset is None:
            # This class never assigns self.test_dataset, so look it up
            # defensively instead of assuming it exists.
            test_dataset = getattr(self, "test_dataset", None)
        if test_dataset is None:
            raise ValueError("get_test_dataloader requires a test_dataset.")
        return DataLoader(
            test_dataset,
            shuffle=False,
            batch_size=self.args.per_device_eval_batch_size,
        )

# Constructed at import time: this statement sits outside the __main__ guard.
model = Reservoir_fl_model(configuration)

if __name__ == "__main__":
    # Passing callbacks=[early_stopping_callback] would enable early stopping;
    # it is currently left out.
    trainer = ReservoirTrainer(
        args=training_args,
        model=model,
        compute_metrics=compute_metrics,
        eval_dataset=test_dataset_fl,
        train_dataset=train_dataset_fl,
    )

    # Run the training loop.
    trainer.train()
# Training loop

#res_state = torch.zeros(1, 1000)
#for batch in train_loader:
#    inputs_embeds = batch['inputs_embeds']  # Extract input sequences from the batch
#    labels_ids = batch['labels_ids']        # Extract target sequences from the batch#

    # Forward pass through the DeepReservoirNet
#    reservoir_outputs = model(inputs_embeds=inputs_embeds)

    # Get the model's outputs and updated reservoir state
    #outputs = output_dict['Output']
    #res_state = output_dict['State']
#    print(reservoir_outputs.shape) #the output shape is (batch_size,output_size,num_features)
#    print(reservoir_outputs)
#    break  #next step is to keep track of all Reservoir states across all batches
           #next step is to use cross-attention to combine the input and reservoir_outputs

# Step 4: Forward pass through the model
#output_dict = model(train_dataset)
#print(res_state.shape)